{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a73a40df",
   "metadata": {},
   "source": [
    "# Watermarking custom audio\n",
    "\n",
    "[[`arXiv`](https://arxiv.org/abs/2401.17264)]\n",
    "[[`GitHub`](https://github.com/facebookresearch/audioseal)]\n",
    "\n",
    "This notebook shows a minimal example of how to watermark custom audio, such as your own recorded voice. It is designed to run in Google Colab. Before starting, get familiar with the AudioSeal API, for example through the [Getting Started notebook](./Getting_started.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c2562ce",
   "metadata": {},
   "source": [
    "## Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fbb4b36",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Install requirements\n",
    "!pip install torchaudio\n",
    "!pip install matplotlib\n",
    "!pip install audioseal\n",
    "!pip install ffmpeg-python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1325f7d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import ffmpeg\n",
    "import IPython.display as ipd\n",
    "from google.colab.output import eval_js\n",
    "\n",
    "from base64 import b64decode\n",
    "from scipy.io.wavfile import read as wav_read\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torchaudio\n",
    "\n",
    "from audioseal import AudioSeal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7f95544",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_waveform_and_specgram(waveform, sample_rate, title):\n",
    "    waveform = waveform.squeeze().detach().cpu().numpy()\n",
    "\n",
    "    num_frames = waveform.shape[-1]\n",
    "    time_axis = torch.arange(0, num_frames) / sample_rate\n",
    "\n",
    "    figure, (ax1, ax2) = plt.subplots(1, 2)\n",
    "\n",
    "    ax1.plot(time_axis, waveform, linewidth=1)\n",
    "    ax1.grid(True)\n",
    "    ax2.specgram(waveform, Fs=sample_rate)\n",
    "\n",
    "    figure.suptitle(f\"{title} - Waveform and specgram\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def play_audio(waveform, sample_rate):\n",
    "    if waveform.dim() > 2:\n",
    "        waveform = waveform.squeeze(0)\n",
    "    waveform = waveform.detach().cpu().numpy()\n",
    "\n",
    "    num_channels, *_ = waveform.shape\n",
    "    if num_channels == 1:\n",
    "        ipd.display(ipd.Audio(waveform[0], rate=sample_rate))\n",
    "    elif num_channels == 2:\n",
    "        ipd.display(ipd.Audio((waveform[0], waveform[1]), rate=sample_rate))\n",
    "    else:\n",
    "        raise ValueError(\"Waveforms with more than 2 channels are not supported.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daf14d39",
   "metadata": {},
   "outputs": [],
   "source": [
    "AUDIO_HTML = \"\"\"\n",
    "<script>\n",
    "var my_div = document.createElement(\"DIV\");\n",
    "var my_p = document.createElement(\"P\");\n",
    "var my_btn = document.createElement(\"BUTTON\");\n",
    "var t = document.createTextNode(\"Press to start recording\");\n",
    "\n",
    "my_btn.appendChild(t);\n",
    "//my_p.appendChild(my_btn);\n",
    "my_div.appendChild(my_btn);\n",
    "document.body.appendChild(my_div);\n",
    "\n",
    "var base64data = 0;\n",
    "var reader;\n",
    "var recorder, gumStream;\n",
    "var recordButton = my_btn;\n",
    "\n",
    "var handleSuccess = function(stream) {\n",
    "  gumStream = stream;\n",
    "  var options = {\n",
    "    //bitsPerSecond: 8000, //chrome seems to ignore, always 48k\n",
    "    mimeType : 'audio/webm;codecs=opus'\n",
    "    //mimeType : 'audio/webm;codecs=pcm'\n",
    "  };\n",
    "  //recorder = new MediaRecorder(stream, options);\n",
    "  recorder = new MediaRecorder(stream);\n",
    "  recorder.ondataavailable = function(e) {\n",
    "    var url = URL.createObjectURL(e.data);\n",
    "    var preview = document.createElement('audio');\n",
    "    preview.controls = true;\n",
    "    preview.src = url;\n",
    "    document.body.appendChild(preview);\n",
    "\n",
    "    reader = new FileReader();\n",
    "    reader.readAsDataURL(e.data);\n",
    "    reader.onloadend = function() {\n",
    "      base64data = reader.result;\n",
    "      //console.log(\"Inside FileReader:\" + base64data);\n",
    "    }\n",
    "  };\n",
    "  recorder.start();\n",
    "  };\n",
    "\n",
    "recordButton.innerText = \"Recording... press to stop\";\n",
    "\n",
    "navigator.mediaDevices.getUserMedia({audio: true}).then(handleSuccess);\n",
    "\n",
    "\n",
    "function toggleRecording() {\n",
    "  if (recorder && recorder.state == \"recording\") {\n",
    "      recorder.stop();\n",
    "      gumStream.getAudioTracks()[0].stop();\n",
    "      recordButton.innerText = \"Saving the recording... pls wait!\"\n",
    "  }\n",
    "}\n",
    "\n",
    "// https://stackoverflow.com/a/951057\n",
    "function sleep(ms) {\n",
    "  return new Promise(resolve => setTimeout(resolve, ms));\n",
    "}\n",
    "\n",
    "var data = new Promise(resolve=>{\n",
    "//recordButton.addEventListener(\"click\", toggleRecording);\n",
    "recordButton.onclick = ()=>{\n",
    "toggleRecording()\n",
    "\n",
    "sleep(2000).then(() => {\n",
    "  // wait 2000ms for the data to be available...\n",
    "  // ideally this should use something like await...\n",
    "  //console.log(\"Inside data:\" + base64data)\n",
    "  resolve(base64data.toString())\n",
    "\n",
    "});\n",
    "\n",
    "}\n",
    "});\n",
    "\n",
    "</script>\n",
    "\"\"\"\n",
    "\n",
    "def get_audio():\n",
    "  display(ipd.HTML(AUDIO_HTML))\n",
    "  data = eval_js(\"data\")\n",
    "  binary = b64decode(data.split(',')[1])\n",
    "\n",
    "  process = (ffmpeg\n",
    "    .input('pipe:0')\n",
    "    .output('pipe:1', format='wav')\n",
    "    .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True, quiet=True, overwrite_output=True)\n",
    "  )\n",
    "  output, err = process.communicate(input=binary)\n",
    "\n",
    "  # ffmpeg cannot seek back on a pipe to fill in the RIFF chunk size in the WAV header,\n",
    "  # so replace bytes 4:8 of `output` with the actual size as a 4-byte little-endian integer.\n",
    "  riff_chunk_size = len(output) - 8\n",
    "  riff = output[:4] + riff_chunk_size.to_bytes(4, \"little\") + output[8:]\n",
    "\n",
    "  sr, audio = wav_read(io.BytesIO(riff))\n",
    "\n",
    "  return audio, sr"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53826104",
   "metadata": {},
   "source": [
    "## Record your audio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3216ff1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "recorded, sr = get_audio()\n",
    "\n",
    "# Convert the int16 samples to a float tensor normalized to [-1, 1)\n",
    "audio = torch.tensor(recorded).float() / 32768.0\n",
    "print(audio.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2110143",
   "metadata": {},
   "source": [
    "## Generator\n",
    "\n",
    "To watermark audio, load the watermarking generator from the hub:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "007c48cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AudioSeal.load_generator(\"audioseal_wm_16bits\")\n",
    "\n",
    "# Add batch and channel dimensions to the single recording to mimic batch watermarking\n",
    "audios = audio.unsqueeze(0).unsqueeze(0)  # b=1 c=1 t\n",
    "\n",
    "watermark = model.get_watermark(audios, sample_rate=sr)\n",
    "watermarked_audio = audios + watermark\n",
    "\n",
    "# Alternatively, call the model directly (forward()) and use `alpha` to tune the watermark strength down or up\n",
    "watermarked_audio = model(audios, sample_rate=sr, alpha=1)\n",
    "\n",
    "# You can also watermark with a secret message\n",
    "# secret_message = torch.randint(0, 2, (1, 16), dtype=torch.int32)\n",
    "# watermarked_audio = model(audios, sample_rate=sr, message=secret_message, alpha=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac5aac4f",
   "metadata": {},
   "source": [
    "The watermarked audio keeps almost the same spectrogram and content as the original."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0200bc22",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_waveform_and_specgram(watermarked_audio.squeeze(), sr, title=\"Watermarked audio\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d70a0d73",
   "metadata": {},
   "outputs": [],
   "source": [
    "play_audio(watermarked_audio, sr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd7fa4fe",
   "metadata": {},
   "source": [
    "## Detector\n",
    "\n",
    "To detect watermarks in audio, load the separate detector model and use it in one of the following ways:\n",
    "\n",
    "### Basic usage: Call `detect_watermark()`\n",
    "\n",
    "This returns a tuple of the form `Tuple[float, Tensor]`. The first value is the probability that the audio is watermarked (the higher, the more likely), and the second is the decoded message that was embedded by the generator. If the audio is unwatermarked (low first value), the decoded message is just random bits.\n",
    "\n",
    "Note that due to the stochastic nature of the detector, the decoded message may differ from the secret message by 1 bit. Depending on your needs, you can run detection multiple times and average the decoded messages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b1a3a9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "detector = AudioSeal.load_detector(\"audioseal_detector_16bits\")\n",
    "\n",
    "result, message = detector.detect_watermark(watermarked_audio, sample_rate=sr, message_threshold=0.5)\n",
    "\n",
    "print(f\"\\nThis is likely a watermarked audio. WM probability: {result}\")\n",
    "\n",
    "# Run on an unwatermarked audio\n",
    "result2, message2 = detector.detect_watermark(audios, sample_rate=sr, message_threshold=0.5)\n",
    "print(f\"This is likely an unwatermarked audio. WM probability: {result2}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7730364d",
   "metadata": {},
   "outputs": [],
   "source": [
    "message"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8dc67150",
   "metadata": {},
   "source": [
    "`message_threshold` is the threshold the detector uses to convert the per-bit probabilities (between 0 and 1) into the n-bit binary message. In most cases the generator produces an unbiased message from the secret, so `0.5` is a reasonable choice: in the example above, a value > 0.5 decodes to 1 and a value < 0.5 decodes to 0.\n",
    "\n",
    "\n",
    "### Advanced usage: Call `forward()`\n",
    "\n",
    "The detector can also be called directly as a Torch module. This returns 2 tensors:\n",
    "- The first tensor, of size `batch x 2 x frames`, gives the probability of each frame being watermarked (positive or negative): `t[:, 0, :]` is the negative probability and `t[:, 1, :]` is the positive probability.\n",
    "- The second tensor, of size `batch x n_bits`, is the message detected from the audio. It gives the probability of each bit being 1. For unwatermarked audio, this tensor is random."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fadf26a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_prob, message_prob = detector(watermarked_audio, sample_rate=sr)\n",
    "pred_prob[:, 1, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "899de6b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "message_prob"
   ]
  },
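  {
   "cell_type": "markdown",
   "id": "b41e7c2a",
   "metadata": {},
   "source": [
    "As a minimal sketch of the `message_threshold` step described earlier, the cell below applies a threshold to per-bit probabilities. The tensor `example_prob` and its values are made up for illustration; in practice you would use `message_prob` from the cell above. The comparison itself is an assumption based on the description of `message_threshold`, not a copy of the library code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c93d05f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical per-bit probabilities (batch x n_bits), standing in for `message_prob`\n",
    "example_prob = torch.tensor([[0.91, 0.07, 0.63, 0.48]])\n",
    "message_threshold = 0.5\n",
    "\n",
    "# Bits with a probability above the threshold decode to 1, all others to 0\n",
    "decoded_bits = (example_prob > message_threshold).int()\n",
    "print(decoded_bits)"
   ]
  },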
  {
   "cell_type": "markdown",
   "id": "a78766fd",
   "metadata": {},
   "source": [
    "### Robustness against attacks\n",
    "\n",
    "We can evaluate the robustness of the detector against some attacks. Here we perform two simple ones: pink noise and a lowpass filter. For the full list of attacks, please refer to our paper.\n",
    "\n",
    "\n",
    "#### Pink noise attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cc0efde",
   "metadata": {},
   "outputs": [],
   "source": [
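    "# `attacks` is a local module (attacks.py) expected next to this notebook; in Colab, upload it first\n",
    "from attacks import AudioEffects as af\n",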
    "\n",
    "pink_noised_audio = af.pink_noise(watermarked_audio, noise_std=0.1)\n",
    "plot_waveform_and_specgram(pink_noised_audio, sample_rate=sr, title=\"Audio with pink noise\")\n",
    "result, message = detector.detect_watermark(pink_noised_audio, sample_rate=sr)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2b96eb1",
   "metadata": {},
   "source": [
    "#### Lowpass filter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "254e6012",
   "metadata": {},
   "outputs": [],
   "source": [
    "lowpass_filtered = af.lowpass_filter(watermarked_audio, cutoff_freq=5000, sample_rate=sr)\n",
    "plot_waveform_and_specgram(lowpass_filtered, sample_rate=sr, title=\"Audio with lowpass filter\")\n",
    "result, message = detector.detect_watermark(lowpass_filtered, sample_rate=sr)\n",
    "print(result)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "awm-oss",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
